# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch

import torch.nn as nn

from msmodelslim.core.const import DeviceType
from msmodelslim.core.graph import AdapterConfig
from msmodelslim.model.qwen3_moe.model_adapter import Qwen3MoeModelAdapter


class DummyConfig:
    """模拟配置对象"""

    def __init__(self):
        self.num_hidden_layers = 3


class TestQwen3MoeModelAdapter(unittest.TestCase):
    """测试Qwen3MoeModelAdapter的功能"""

    def setUp(self):
        """测试前的准备工作"""
        self.model_type = 'Qwen3-30B'
        self.model_path = Path('.')
        self.trust_remote_code = False

    def test_get_model_type_when_initialized_then_return_model_type(self):
        """测试get_model_type方法：初始化后应返回正确的模型类型"""
        with patch('msmodelslim.model.qwen3_moe.model_adapter.TransformersModel.__init__', return_value=None):
            adapter = Qwen3MoeModelAdapter(
                model_type=self.model_type,
                model_path=self.model_path,
                trust_remote_code=self.trust_remote_code
            )
            adapter.model_type = self.model_type

            result = adapter.get_model_type()

            self.assertEqual(result, self.model_type)

    def test_get_model_pedigree_when_called_then_return_qwen3_moe(self):
        """测试get_model_pedigree方法：应返回'qwen3_moe'"""
        with patch('msmodelslim.model.qwen3_moe.model_adapter.TransformersModel.__init__', return_value=None):
            adapter = Qwen3MoeModelAdapter(
                model_type=self.model_type,
                model_path=self.model_path
            )

            result = adapter.get_model_pedigree()

            self.assertEqual(result, 'qwen3_moe')

    def test_load_model_when_called_then_delegate_to_load_model(self):
        """测试load_model方法：应委托给_load_model方法"""
        with patch('msmodelslim.model.qwen3_moe.model_adapter.TransformersModel.__init__', return_value=None):
            adapter = Qwen3MoeModelAdapter(
                model_type=self.model_type,
                model_path=self.model_path
            )

            mock_model = nn.Linear(10, 10)
            adapter._load_model = MagicMock(return_value=mock_model)

            result = adapter.load_model(device=DeviceType.NPU)

            self.assertIs(result, mock_model)
            adapter._load_model.assert_called_once_with(DeviceType.NPU)

    def test_handle_dataset_when_called_then_return_tokenized_data(self):
        """测试handle_dataset方法：应返回tokenized数据"""
        with patch('msmodelslim.model.qwen3_moe.model_adapter.TransformersModel.__init__', return_value=None):
            adapter = Qwen3MoeModelAdapter(
                model_type=self.model_type,
                model_path=self.model_path
            )

            mock_dataset = ['data1', 'data2']
            adapter._get_tokenized_data = MagicMock(return_value=mock_dataset)

            result = adapter.handle_dataset(dataset='test_data', device=DeviceType.CPU)

            self.assertEqual(result, mock_dataset)
            adapter._get_tokenized_data.assert_called_once_with('test_data', DeviceType.CPU)

    def test_handle_dataset_by_batch_when_called_then_return_batch_tokenized_data(self):
        """测试handle_dataset_by_batch方法：应返回批量tokenized数据"""
        with patch('msmodelslim.model.qwen3_moe.model_adapter.TransformersModel.__init__', return_value=None):
            adapter = Qwen3MoeModelAdapter(
                model_type=self.model_type,
                model_path=self.model_path
            )

            mock_batch_dataset = [['batch1'], ['batch2']]
            adapter._get_batch_tokenized_data = MagicMock(return_value=mock_batch_dataset)

            result = adapter.handle_dataset_by_batch(
                dataset='test_data',
                batch_size=2,
                device=DeviceType.CPU
            )

            self.assertEqual(result, mock_batch_dataset)
            adapter._get_batch_tokenized_data.assert_called_once_with(
                calib_list='test_data',
                batch_size=2,
                device=DeviceType.CPU
            )

    def test_init_model_when_called_then_delegate_to_load_model(self):
        """测试init_model方法：应委托给_load_model方法"""
        with patch('msmodelslim.model.qwen3_moe.model_adapter.TransformersModel.__init__', return_value=None):
            adapter = Qwen3MoeModelAdapter(
                model_type=self.model_type,
                model_path=self.model_path
            )

            mock_model = nn.Linear(10, 10)
            adapter._load_model = MagicMock(return_value=mock_model)

            result = adapter.init_model(device=DeviceType.NPU)

            self.assertIs(result, mock_model)
            adapter._load_model.assert_called_once_with(DeviceType.NPU)

    def test_enable_kv_cache_when_called_then_delegate_to_enable_kv_cache(self):
        """测试enable_kv_cache方法：应委托给_enable_kv_cache方法"""
        with patch('msmodelslim.model.qwen3_moe.model_adapter.TransformersModel.__init__', return_value=None):
            adapter = Qwen3MoeModelAdapter(
                model_type=self.model_type,
                model_path=self.model_path
            )

            mock_model = nn.Linear(10, 10)
            adapter._enable_kv_cache = MagicMock(return_value=None)

            result = adapter.enable_kv_cache(model=mock_model, need_kv_cache=True)

            # 验证_enable_kv_cache被调用
            adapter._enable_kv_cache.assert_called_once_with(mock_model, True)

    def test_enable_kv_cache_with_false_then_disable_cache(self):
        """测试enable_kv_cache方法：传入False时应禁用缓存"""
        with patch('msmodelslim.model.qwen3_moe.model_adapter.TransformersModel.__init__', return_value=None):
            adapter = Qwen3MoeModelAdapter(
                model_type=self.model_type,
                model_path=self.model_path
            )

            mock_model = nn.Linear(10, 10)
            adapter._enable_kv_cache = MagicMock(return_value=None)

            adapter.enable_kv_cache(model=mock_model, need_kv_cache=False)

            # 验证参数正确传递
            adapter._enable_kv_cache.assert_called_once_with(mock_model, False)

    def test_get_adapter_config_for_subgraph_when_called_then_return_adapter_configs(self):
        """测试get_adapter_config_for_subgraph方法：应返回适配器配置列表"""
        with patch('msmodelslim.model.qwen3_moe.model_adapter.TransformersModel.__init__', return_value=None):
            adapter = Qwen3MoeModelAdapter(
                model_type=self.model_type,
                model_path=self.model_path
            )
            adapter.config = DummyConfig()

            result = adapter.get_adapter_config_for_subgraph()

            # 验证返回列表
            self.assertIsInstance(result, list)

            # 每层应该有2个配置（Norm-Linear融合 + OV融合）
            expected_configs = adapter.config.num_hidden_layers * 2
            self.assertEqual(len(result), expected_configs)

            # 验证第一个配置的类型
            self.assertIsInstance(result[0], AdapterConfig)

    def test_get_adapter_config_for_subgraph_when_zero_layers_then_return_empty_list(self):
        """测试get_adapter_config_for_subgraph方法：0层时应返回空列表"""
        with patch('msmodelslim.model.qwen3_moe.model_adapter.TransformersModel.__init__', return_value=None):
            adapter = Qwen3MoeModelAdapter(
                model_type=self.model_type,
                model_path=self.model_path
            )
            adapter.config = DummyConfig()
            adapter.config.num_hidden_layers = 0

            result = adapter.get_adapter_config_for_subgraph()

            # 验证返回空列表
            self.assertEqual(len(result), 0)
            self.assertIsInstance(result, list)


class TestQwen3MoeModuleFunctions(unittest.TestCase):

    def setUp(self):
        """测试前的准备工作"""
        self.config = DummyConfig()
        # 为测试函数添加必要的配置属性
        self.config.num_experts = 4  # 添加专家数量
        self.config.hidden_size = 128  # 添加隐藏层大小
        self.config.head_dim = 64  # 添加头维度

    def test_qwen3_moe_get_ln_fuse_map_when_called_then_return_correct_mapping(self):
        """测试qwen3_moe_get_ln_fuse_map方法：调用时应返回正确的融合映射"""
        from msmodelslim.model.qwen3_moe.model_adapter import qwen3_moe_get_ln_fuse_map

        result = qwen3_moe_get_ln_fuse_map(self.config)

        # 验证返回字典
        self.assertIsInstance(result, dict)

        # 验证包含必要的键
        self.assertIn("model.norm", result)
        self.assertIn("model.layers.0.input_layernorm", result)
        self.assertIn("model.layers.0.post_attention_layernorm", result)

        # 验证input_layernorm的映射
        input_ln_targets = result["model.layers.0.input_layernorm"]
        expected_input_targets = [
            "model.layers.0.self_attn.q_proj",
            "model.layers.0.self_attn.k_proj",
            "model.layers.0.self_attn.v_proj"
        ]
        for target in expected_input_targets:
            self.assertIn(target, input_ln_targets)

        # 验证post_attention_layernorm的映射包含专家和gate
        post_ln_targets = result["model.layers.0.post_attention_layernorm"]
        self.assertIn("model.layers.0.mlp.gate", post_ln_targets)

        # 验证专家数量正确
        expert_targets = [t for t in post_ln_targets if "experts" in t]
        self.assertEqual(len(expert_targets), self.config.num_experts * 2)  # gate_proj + up_proj

    def test_qwen3_moe_get_rotate_map_when_called_then_return_correct_rotate_commands(self):
        """测试qwen3_moe_get_rotate_map方法：调用时应返回正确的旋转命令"""
        from msmodelslim.model.qwen3_moe.model_adapter import qwen3_moe_get_rotate_map
        from msmodelslim.quant.processor.quarot import QuaRotInterface

        block_size = 64
        pre_run, rot_pairs, rot, rot_uv = qwen3_moe_get_rotate_map(self.config, block_size)

        # 验证返回类型
        self.assertIsInstance(pre_run, QuaRotInterface.RotatePair)
        self.assertIsInstance(rot_pairs, dict)

        # 验证pre_run包含embed_tokens
        self.assertIn("model.embed_tokens", pre_run.right_rot)

        # 验证rot_pairs包含rot和rot_uv
        self.assertIn("rot", rot_pairs)
        self.assertIn("rot_uv", rot_pairs)

        # 验证rot包含正确的层数
        rot_pair = rot_pairs["rot"]
        self.assertIn("lm_head", rot_pair.right_rot)

        # 验证包含self_attn层
        for layer_idx in range(self.config.num_hidden_layers):
            for proj in ["q_proj", "k_proj", "v_proj"]:
                key = f"model.layers.{layer_idx}.self_attn.{proj}"
                self.assertIn(key, rot_pair.right_rot)

            # 验证o_proj在left_rot中
            o_proj_key = f"model.layers.{layer_idx}.self_attn.o_proj"
            self.assertIn(o_proj_key, rot_pair.left_rot)

            # 验证专家层
            for i in range(self.config.num_experts):
                for proj in ["gate_proj", "up_proj"]:
                    key = f"model.layers.{layer_idx}.mlp.experts.{i}.{proj}"
                    self.assertIn(key, rot_pair.right_rot)

                down_proj_key = f"model.layers.{layer_idx}.mlp.experts.{i}.down_proj"
                self.assertIn(down_proj_key, rot_pair.left_rot)

            # 验证gate
            gate_key = f"model.layers.{layer_idx}.mlp.gate"
            self.assertIn(gate_key, rot_pair.right_rot)


if __name__ == '__main__':
    unittest.main()
